import jittor as jt
from jittor import nn
from jittor.models import resnet
jt.flags.use_cuda = 1
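
# Model definitions: a latent-conditioned U-Net generator, a ResNet-18 based encoder that
# predicts (mu, logvar) of a latent distribution, and a multi-discriminator with three
# identical PatchGAN-style branches.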

def weights_init_normal(m):
    """Initialize Conv weights from N(0, 0.02), BatchNorm weights from N(1, 0.02) and BatchNorm biases to 0."""
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        jt.init.gauss_(m.weight, 0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        jt.init.gauss_(m.weight, 1.0, 0.02)
        jt.init.constant_(m.bias, 0.0)

class UNetDown(nn.Module):
    def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
        super(UNetDown, self).__init__()
        layers = [nn.Conv(in_size, out_size, 3, stride=2, padding=1, bias=False)]
        if normalize:
            layers.append(nn.BatchNorm(out_size, 0.8))
        layers.append(nn.Leaky_relu(0.2))
        # Optional dropout; previously the dropout argument was accepted but never applied
        if dropout:
            layers.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*layers)

    def execute(self, x):
        return self.model(x)

class UNetUp(nn.Module):
    def __init__(self, in_size, out_size):
        super(UNetUp, self).__init__()
        self.model = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv(in_size, out_size, 3, stride=1, padding=1, bias=False),
            nn.BatchNorm(out_size, 0.8),
            nn.Relu()
        )

    def execute(self, x, skip_input):
        x = self.model(x)
        x = jt.contrib.concat([x, skip_input], dim=1)
        return x

class Generator(nn.Module):
    def __init__(self, latent_dim, img_shape):
        super(Generator, self).__init__()
        (channels, self.h, self.w) = img_shape
        self.fc = nn.Linear(latent_dim, (self.h * self.w))
        self.down1 = UNetDown((channels + 1), 64, normalize=False)
        self.down2 = UNetDown(64, 128)
        self.down3 = UNetDown(128, 256)
        self.down4 = UNetDown(256, 512)
        self.down5 = UNetDown(512, 512)
        self.down6 = UNetDown(512, 512)
        self.down7 = UNetDown(512, 512, normalize=False)
        self.up1 = UNetUp(512, 512)
        self.up2 = UNetUp(1024, 512)
        self.up3 = UNetUp(1024, 512)
        self.up4 = UNetUp(1024, 256)
        self.up5 = UNetUp(512, 128)
        self.up6 = UNetUp(256, 64)
        self.final = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv(128, channels, 3, stride=1, padding=1),
            nn.Tanh()
        )

        for m in self.modules():
            weights_init_normal(m)

    def execute(self, x, z):
        # Project the latent code to an h x w map and append it to the image as an extra channel
        z = self.fc(z)
        z = jt.reshape(z, [z.shape[0], 1, self.h, self.w])
        d1 = self.down1(jt.contrib.concat([x, z], dim=1))
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        d7 = self.down7(d6)
        # Each UNetUp upsamples and concatenates the matching encoder feature map (skip connection)
        u1 = self.up1(d7, d6)
        u2 = self.up2(u1, d5)
        u3 = self.up3(u2, d4)
        u4 = self.up4(u3, d3)
        u5 = self.up5(u4, d2)
        u6 = self.up6(u5, d1)
        return self.final(u6)

class Encoder(nn.Module):

    def __init__(self, latent_dim, input_shape):
        super(Encoder, self).__init__()
        resnet18_model = resnet.Resnet18()
        # Keep the ResNet-18 trunk up to layer3 (the last three children are dropped), which outputs 256 channels
        self.feature_extractor = nn.Sequential(*list(resnet18_model.children())[:-3])
        # Average-pool the remaining spatial map (8x8 for 128x128 inputs) down to a single 256-d vector
        self.pooling = nn.Pool(kernel_size=8, stride=8, padding=0, op='mean')
        # Two linear heads predicting the mean and log-variance of the latent distribution
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_logvar = nn.Linear(256, latent_dim)
        for m in self.modules():
            weights_init_normal(m)

    def execute(self, img):
        out = self.feature_extractor(img)
        out = self.pooling(out)
        out = jt.reshape(out, [out.shape[0], -1])
        mu = self.fc_mu(out)
        logvar = self.fc_logvar(out)
        return mu, logvar
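
# The encoder only returns (mu, logvar); drawing a latent code from that distribution is assumed
# to happen in the training script. Below is a minimal sketch of the usual reparameterization
# trick (z = mu + std * eps with eps ~ N(0, 1)); the helper name is illustrative, not part of
# the original file.
def reparameterization(mu, logvar):
    std = jt.exp(logvar / 2)
    eps = jt.randn(mu.shape[0], mu.shape[1])
    return mu + eps * std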

class MultiDiscriminator(nn.Module):

    def __init__(self, input_shape):
        super(MultiDiscriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, normalize=True):
            """Return the downsampling layers of one discriminator block."""
            layers = [nn.Conv(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.BatchNorm(out_filters, 0.8))
            layers.append(nn.Leaky_relu(0.2))
            return layers
        (channels, _, _) = input_shape
        # Three PatchGAN-style discriminators with identical architecture; execute() applies each to the same input
        self.disc_0 = nn.Sequential(
            *discriminator_block(channels, 64, normalize=False), 
            *discriminator_block(64, 128), 
            *discriminator_block(128, 256), 
            *discriminator_block(256, 512), 
            nn.Conv(512, 1, 3, padding=1)
        )
        self.disc_1 = nn.Sequential(
            *discriminator_block(channels, 64, normalize=False), 
            *discriminator_block(64, 128), 
            *discriminator_block(128, 256), 
            *discriminator_block(256, 512), 
            nn.Conv(512, 1, 3, padding=1)
        )
        self.disc_2 = nn.Sequential(
            *discriminator_block(channels, 64, normalize=False), 
            *discriminator_block(64, 128), 
            *discriminator_block(128, 256), 
            *discriminator_block(256, 512), 
            nn.Conv(512, 1, 3, padding=1)
        )
        for m in self.modules():
            weights_init_normal(m)
        
    def compute_loss(self, x, gt):
        """Sum the MSE between each discriminator's output map and the scalar target gt."""
        loss = sum([jt.mean((out - gt).sqr()) for out in self.execute(x)])
        return loss

    def execute(self, x):
        outputs = []
        outputs.append(self.disc_0(x))
        outputs.append(self.disc_1(x))
        outputs.append(self.disc_2(x))
        return outputs
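
# Minimal shape check, assuming 128x128 RGB images and an 8-dimensional latent code
# (illustrative values; the real sizes come from the training script's options).
if __name__ == "__main__":
    img_shape = (3, 128, 128)
    latent_dim = 8
    generator = Generator(latent_dim, img_shape)
    encoder = Encoder(latent_dim, img_shape)
    discriminator = MultiDiscriminator(img_shape)

    x = jt.randn(2, 3, 128, 128)
    z = jt.randn(2, latent_dim)
    fake = generator(x, z)        # (2, 3, 128, 128)
    mu, logvar = encoder(fake)    # each (2, latent_dim)
    d_outs = discriminator(fake)  # three 1-channel patch maps
    print(fake.shape, mu.shape, logvar.shape, [o.shape for o in d_outs])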